What is the Issue?

If you have any experience with PyTorch training loops, you might have noticed that they are tedious to write and arguably more complicated than they need to be, especially if you are used to the Sklearn or Keras fit function.

For illustration, here is an example of a PyTorch training loop:

for epoch in range(epochs):
    running_loss = 0.
    for data in train_dataloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()              # reset gradients from the previous step
        outputs = net(inputs)              # forward pass
        loss = criterion(outputs, labels)  # compute the loss
        loss.backward()                    # backward pass
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print("Epoch {} loss: {}".format(epoch + 1, running_loss / len(train_dataloader)))

In comparison, this is what you would do in Keras:

model.compile(optimizer, loss)
model.fit(inputs, labels, batch_size, epochs)

What is Ignite?

Ignite is a high-level library to help with training neural networks in PyTorch.

[Figure: Ignite vs. PyTorch]

Ignite Core Elements

Before you can use Ignite, it has to be installed. You can install it via pip

pip install pytorch-ignite

or via conda

conda install ignite -c pytorch

Engine

The central element of Ignite is the so-called engine. The engine defines the training loop and fires events, which will be discussed in the next section. The Engine class takes a process function as an argument. You define this function yourself, and it might look like this:

def process_function(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

That is your training loop; whatever the process function returns is stored in engine.state.output. Now we just have to run it, which is done by calling the run method on the engine.

engine = Engine(process_function)
engine.run(dataloader)

The dataloader should be a PyTorch DataLoader object. You can, for example, define one for your training set and another for your validation/test set.

If you have a fairly standard supervised learning task, Ignite offers the helper functions create_supervised_trainer and create_supervised_evaluator. The trainer takes your network, an optimizer, and a loss function; the evaluator takes your network and, optionally, a dictionary of metrics you want to evaluate. Both additionally accept the device to run on (CPU or GPU) as well as a prepare_batch and a non_blocking argument.
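
For a classification task, a minimal sketch might look like this (model, optimizer, criterion, and the dataloaders are assumed to be defined as usual):

from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# the trainer runs the optimization loop, the evaluator only computes metrics
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy(), "loss": Loss(criterion)},
    device=device)

trainer.run(train_dataloader, max_epochs=epochs)
evaluator.run(val_dataloader)
print(evaluator.state.metrics)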

Events

Events mark specific points in the training process. The events fired by the engine are:

  • COMPLETED
  • EPOCH_COMPLETED
  • EPOCH_STARTED
  • EXCEPTION_RAISED
  • ITERATION_COMPLETED
  • ITERATION_STARTED
  • STARTED

You can use these events, for example, to perform some printing or logging. The practical way to do so is to define handlers that react to a given event, as shown in the next section.

Handlers
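
A handler is simply a function that gets called whenever the event it is attached to fires. You can attach one either with the @engine.on decorator or with engine.add_event_handler. As a minimal sketch, reusing the engine and dataloader from the Engine section:

from ignite.engine import Events

@engine.on(Events.EPOCH_COMPLETED)
def log_epoch(engine):
    # engine.state.output holds the return value of the process function,
    # here the loss of the last processed batch
    print("Epoch {}: last batch loss {}".format(engine.state.epoch, engine.state.output))

# equivalently, without the decorator:
# engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)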

Practical Example

Having covered the core concepts of the library, I want to give a short example of what it can look like in your project. The following example also utilizes TensorBoardX.

import torch
import torch.nn as nn
import torch.optim as optim
import ignite.metrics
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from tensorboardX import SummaryWriter

loss = nn.MSELoss()

def create_summary_writer(model, data_loader, log_dir):
    writer = SummaryWriter(log_dir=log_dir)
    data_loader_iter = iter(data_loader)
    x, y = next(data_loader_iter)
    x = x.view(x.shape[0], -1)
    try:
        writer.add_graph(model, x)
    except Exception as e:
        print("Failed to save model graph: {}".format(e))
    return writer


def fit(model, loss, dataloaders, epochs=10, learning_rate=1e-2, log_interval=10, log_dir="board"):
    train_loader, val_loader = dataloaders
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    writer = create_summary_writer(model, train_loader, log_dir)
    optimizer = optim.Adam(model.parameters(), learning_rate)

    trainer = create_supervised_trainer(model, optimizer, loss, device=device)
    evaluator = create_supervised_evaluator(model, device=device,
                                            metrics={'loss': ignite.metrics.Loss(loss)})
    
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        it = (engine.state.iteration - 1) % len(train_loader) + 1
        if it % log_interval == 0:
            writer.add_scalar("training/loss", trainer.state.output, trainer.state.iteration)
    
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        writer.add_scalar("training/avg_loss", metrics['loss'], trainer.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        writer.add_scalar("validation/avg_loss", metrics['loss'], trainer.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)
    writer.close()
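
To tie everything together, calling fit could look as follows. The toy regression data and the two-layer network here are purely hypothetical stand-ins for your own dataset and model:

from torch.utils.data import DataLoader, TensorDataset

# hypothetical toy data, just to show the call signature
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
train_ds = TensorDataset(X[:800], y[:800])
val_ds = TensorDataset(X[800:], y[800:])

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
dataloaders = (DataLoader(train_ds, batch_size=32, shuffle=True),
               DataLoader(val_ds, batch_size=32))

fit(model, loss, dataloaders, epochs=5)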
